{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a RL Agent with Stable-Baselines3 Using a GEM Environment\n",
    "\n",
    "This notebook serves as an educational introduction to the usage of Stable-Baselines3 using a gym-electric-motor (GEM) environment. The goal of this notebook is to give an understanding of what Stable-Baselines3 is and how to use it to train and evaluate a reinforcement learning agent that can solve a current control problem of the GEM toolbox."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Installation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before you can start you need to make sure that you have both gym-electric-motor and Stable-Baselines3 installed. You can install both easily using pip:\n",
    "\n",
    "- ```pip install gym-electric-motor```\n",
    "- ```pip install stable-baselines3```\n",
    "\n",
    "Alternatively, you can install them and their latest developer version directly from GitHub:\n",
    "\n",
    "- https://github.com/upb-lea/gym-electric-motor\n",
    "- https://github.com/DLR-RM/stable-baselines3\n",
    "\n",
    "For this notebook, the following cell will do the job:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q git+https://github.com/upb-lea/gym-electric-motor.git git+https://github.com/DLR-RM/stable-baselines3.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Setting up a GEM Environment\n",
    "\n",
    "The basic idea behind reinforcement learning is to create a so-called agent, that should learn by itself to solve a specified task in a given environment. \n",
    "This environment gives the agent feedback on its actions and reinforces the targeted behavior.\n",
    "In this notebook, the task is to train a controller for the current control of a *permanent magnet synchronous motor* (*PMSM*).\n",
    " \n",
    "In the following, the used GEM-environment is briefly presented, but this notebook does not focus directly on the detailed usage of GEM. If you are new to the used environment and interested in finding out what it does and how to use it, you should take a look at the [GEM cookbook](https://colab.research.google.com/github/upb-lea/gym-electric-motor/blob/master/examples/example_notebooks/GEM_cookbook.ipynb).\n",
    "\n",
    "The basic idea of the control setup from the GEM-environment is displayed in the following figure. \n",
    "\n",
    "![](../../docs/plots/SCML_Overview.png)\n",
    "\n",
    "The agent controls the converter who converts the supply currents to the currents flowing into the motor - for the *PMSM*: $i_{sq}$ and $i_{sd}$\n",
    "\n",
    "In the continuous case, the agent's action equals a duty cycle which will be modulated into a corresponding voltage. \n",
    "\n",
    "In the discrete case, the agent's actions denote switching states of the converter at the given instant. Here, only a discrete amount of options are available. In this notebook, for the PMSM the *discrete B6 bridge converter* with six switches is utilized per default. This converter provides a total of eight possible actions.\n",
    "\n",
    "![](../../docs/plots/B6.svg)\n",
    "\n",
    "The motor schematic is the following:\n",
    "\n",
    "\n",
    "![](../../docs/plots/ESBdq.svg)\n",
    "\n",
    "And the electrical ODEs for that motor are:\n",
    "\n",
    "<h3 align=\"center\">\n",
    "\n",
    "<!-- $\\frac{\\mathrm{d}i_{sq}}{\\mathrm{d}t} = \\frac{u_{sq}-pL_d\\omega_{me}i_{sd}-R_si_{sq}}{L_q}$\n",
    "\n",
    "$\\frac{\\mathrm{d}i_{sd}}{\\mathrm{d}t} = \\frac{u_{sd}-pL_q\\omega_{me}i_{sq}-R_si_{sd}}{L_d}$\n",
    "\n",
    "$\\frac{\\mathrm{d}\\epsilon_{el}}{\\mathrm{d}t} = p\\omega_{me}$\n",
    " -->\n",
    "\n",
    "   $ \\frac{\\mathrm{d}i_{sd}}{\\mathrm{d}t}=\\frac{u_{sd} + p\\omega_{me}L_q i_{sq} - R_s i_{sd}}{L_d} $ <br><br>\n",
    "    $\\frac{\\mathrm{d} i_{sq}}{\\mathrm{d} t}=\\frac{u_{sq} - p \\omega_{me} (L_d i_{sd} + \\mathit{\\Psi}_p) - R_s i_{sq}}{L_q}$ <br><br>\n",
    "   $\\frac{\\mathrm{d}\\epsilon_{el}}{\\mathrm{d}t} = p\\omega_{me}$\n",
    "\n",
    "</h3>\n",
    "\n",
    "The target for the agent is now to learn to control the currents. For this, a reference generator produces a trajectory that the agent has to follow. \n",
    "Therefore, it has to learn a function (policy) from given states, references and rewards to appropriate actions.\n",
    "\n",
    "For a deeper understanding of the used models behind the environment see the [documentation](https://upb-lea.github.io/gym-electric-motor/).\n",
    "Comprehensive learning material to RL is also [freely available](https://github.com/upb-lea/reinforcement_learning_course_materials)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "#%matplotlib widget\n",
    "# Use %matplotlib widget in Visual Studio Code\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import gym_electric_motor as gem\n",
    "from gym_electric_motor.reference_generators import \\\n",
    "    MultipleReferenceGenerator,\\\n",
    "    WienerProcessReferenceGenerator\n",
    "from gym_electric_motor.visualization import MotorDashboard\n",
    "from gym_electric_motor.core import Callback\n",
    "from gym_electric_motor.envs.motors import ActionType, ControlType, Motor, MotorType\n",
    "from gym_electric_motor.physical_systems.solvers import EulerSolver\n",
    "from gymnasium.spaces import Discrete, Box\n",
    "from gymnasium.wrappers import FlattenObservation, TimeLimit\n",
    "from gymnasium import ObservationWrapper\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# helper functions and classes\n",
    "\n",
    "class RewardLogger(Callback):\n",
    "    \"\"\"Logs the reward accumulated in each episode\"\"\"\n",
    "    def __init__(self):\n",
    "        self.step_rewards = []\n",
    "        self.mean_episode_rewards = []\n",
    "        dir_path = Path.cwd() /\"rl_frameworks\" / \"saved_agents\"\n",
    "        dir_path.mkdir(parents=True, exist_ok=True)\n",
    "        self.fpath = dir_path  / \"EpisodeRewards.npy\"\n",
    "        \n",
    "    def on_step_end(self, k, state, reference, reward, done):\n",
    "        \"\"\"Stores the received reward at each step\"\"\"\n",
    "        self.step_rewards.append(reward)\n",
    "    \n",
    "    def on_reset_begin(self):\n",
    "        \"\"\"Stores the mean reward received in every episode\"\"\"\n",
    "        if len(self.step_rewards) > 0:\n",
    "            self.mean_episode_rewards.append(np.mean(self.step_rewards))\n",
    "        self.step_rewards = []\n",
    "        \n",
    "    def on_close(self):\n",
    "        \"\"\"Writes the mean episode reward of the experiment to a file.\"\"\"\n",
    "        np.save(self.fpath, np.array(self.mean_episode_rewards))\n",
    "\n",
    "\n",
    "class FeatureWrapper(ObservationWrapper):\n",
    "    \"\"\"\n",
    "    Wrapper class which wraps the environment to change its observation. Serves\n",
    "    the purpose to improve the agent's learning speed.\n",
    "    \n",
    "    It changes epsilon to cos(epsilon) and sin(epsilon). This serves the purpose\n",
    "    to have the angles -pi and pi close to each other numerically without losing\n",
    "    any information on the angle.\n",
    "    \n",
    "    Additionally, this wrapper adds a new observation i_sd**2 + i_sq**2. This should\n",
    "    help the agent to easier detect incoming limit violations.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, epsilon_idx, i_sd_idx, i_sq_idx):\n",
    "        \"\"\"\n",
    "        Changes the observation space to fit the new features\n",
    "        \n",
    "        Args:\n",
    "            env(GEM env): GEM environment to wrap\n",
    "            epsilon_idx(integer): Epsilon's index in the observation array\n",
    "            i_sd_idx(integer): I_sd's index in the observation array\n",
    "            i_sq_idx(integer): I_sq's index in the observation array\n",
    "        \"\"\"\n",
    "        super(FeatureWrapper, self).__init__(env)\n",
    "        self.EPSILON_IDX = epsilon_idx\n",
    "        self.I_SQ_IDX = i_sq_idx\n",
    "        self.I_SD_IDX = i_sd_idx\n",
    "        new_low = np.concatenate((self.env.observation_space.low[\n",
    "                                  :self.EPSILON_IDX], np.array([-1.]),\n",
    "                                  self.env.observation_space.low[\n",
    "                                  self.EPSILON_IDX:], np.array([0.])))\n",
    "        new_high = np.concatenate((self.env.observation_space.high[\n",
    "                                   :self.EPSILON_IDX], np.array([1.]),\n",
    "                                   self.env.observation_space.high[\n",
    "                                   self.EPSILON_IDX:],np.array([1.])))\n",
    "\n",
    "        self.observation_space = Box(new_low, new_high)\n",
    "\n",
    "    def observation(self, observation):\n",
    "        \"\"\"\n",
    "        Gets called at each return of an observation. Adds the new features to the\n",
    "        observation and removes original epsilon.\n",
    "        \n",
    "        \"\"\"\n",
    "        cos_eps = np.cos(observation[self.EPSILON_IDX] * np.pi)\n",
    "        sin_eps = np.sin(observation[self.EPSILON_IDX] * np.pi)\n",
    "        currents_squared = observation[self.I_SQ_IDX]**2 + observation[self.I_SD_IDX]**2\n",
    "        observation = np.concatenate((observation[:self.EPSILON_IDX],\n",
    "                                      np.array([cos_eps, sin_eps]),\n",
    "                                      observation[self.EPSILON_IDX + 1:],\n",
    "                                      np.array([currents_squared])))\n",
    "        return observation"
   ]
  },
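  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact summary of what the ```FeatureWrapper``` above does, its observation mapping can be written down explicitly. This is a sketch read off the code, where $o$ (the flattened, normalized observation), $k$ (the index of $\\epsilon$) and $n$ (the length of $o$) are notation introduced here for illustration only:\n",
    "\n",
    "$$(o_0, \\dots, o_{k-1}, \\epsilon, o_{k+1}, \\dots, o_{n-1}) \\;\\mapsto\\; (o_0, \\dots, o_{k-1}, \\cos(\\pi\\epsilon), \\sin(\\pi\\epsilon), o_{k+1}, \\dots, o_{n-1}, i_{sd}^2 + i_{sq}^2)$$\n",
    "\n",
    "Since the normalized angle lies in $[-1, 1]$, both ends of the range map to the same point, $(\\cos(\\pm\\pi), \\sin(\\pm\\pi)) = (-1, 0)$, and the wrapped observation has $n + 2$ entries."
   ]
  },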
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define motor arguments\n",
    "motor_parameter = dict(\n",
    "    p=3,  # [p] = 1, nb of pole pairs\n",
    "    r_s=17.932e-3,  # [r_s] = Ohm, stator resistance\n",
    "    l_d=0.37e-3,  # [l_d] = H, d-axis inductance\n",
    "    l_q=1.2e-3,  # [l_q] = H, q-axis inductance\n",
    "    psi_p=65.65e-3,  # [psi_p] = Vs, magnetic flux of the permanent magnet\n",
    ")\n",
    "# supply voltage\n",
    "u_supply = 350\n",
    "\n",
    "# nominal and absolute state limitations\n",
    "nominal_values=dict(\n",
    "    omega=4000*2*np.pi/60,\n",
    "    i=230,\n",
    "    u=u_supply\n",
    ")\n",
    "limit_values=dict(\n",
    "    omega=4000*2*np.pi/60,\n",
    "    i=1.5*230,\n",
    "    u=u_supply\n",
    ")\n",
    "\n",
    "# sampling interval\n",
    "tau = 1e-5\n",
    "\n",
    "# define maximal episode steps\n",
    "max_eps_steps = 10000\n",
    "\n",
    "\n",
    "motor_initializer={'random_init': 'uniform', 'interval': [[-230, 230], [-230, 230], [-np.pi, np.pi]]}\n",
    "reward_function=gem.reward_functions.WeightedSumOfErrors(\n",
    "    reward_weights={'i_sq': 10, 'i_sd': 10},\n",
    "    gamma=0.99,  # discount rate \n",
    "    reward_power=1\n",
    ")\n",
    "reward_logger = RewardLogger()\n",
    "motor_dashboard = MotorDashboard(state_plots=['i_sq', 'i_sd'], reward_plot=True)\n",
    "motor = Motor(MotorType.PermanentMagnetSynchronousMotor,\n",
    "              ControlType.CurrentControl,\n",
    "              ActionType.Finite)\n",
    "#solver = EulerSolver()\n",
    "# creating gem environment\n",
    "env = gem.make(  # define a PMSM with discrete action space\n",
    "    #\"Finite-CC-PMSM-v0\",\n",
    "    motor.env_id(),\n",
    "    # visualize the results\n",
    "    visualization=motor_dashboard,\n",
    "    \n",
    "    # parameterize the PMSM and update limitations\n",
    "    motor=dict(\n",
    "        motor_parameter=motor_parameter,\n",
    "        limit_values=limit_values,\n",
    "        nominal_values=nominal_values,\n",
    "        motor_initializer=motor_initializer,\n",
    "    ),\n",
    "    # define the random initialisation for load and motor\n",
    "    load=dict(\n",
    "        load_initializer={'random_init': 'uniform', },\n",
    "    ),\n",
    "    reward_function=reward_function,\n",
    "    supply=dict(u_nominal=u_supply),\n",
    "    # define the duration of one sampling step\n",
    "    tau=tau,\n",
    "    callbacks=(reward_logger,),\n",
    "    ode_solver=EulerSolver(),\n",
    ")\n",
    "\n",
    "# remove one action from the action space to help the agent speed up its training\n",
    "# this can be done as both switchting states (1,1,1) and (-1,-1,-1) - which are encoded\n",
    "# by action 0 and 7 - both lead to the same zero voltage vector in alpha/beta-coordinates\n",
    "env.action_space = Discrete(7)\n",
    "\n",
    "# applying wrappers\n",
    "eps_idx = env.physical_system.state_names.index('epsilon')\n",
    "i_sd_idx = env.physical_system.state_names.index('i_sd')\n",
    "i_sq_idx = env.physical_system.state_names.index('i_sq')\n",
    "env = TimeLimit(\n",
    "    FeatureWrapper(\n",
    "        FlattenObservation(env), \n",
    "        eps_idx, i_sd_idx, i_sq_idx\n",
    "    ),\n",
    "    max_eps_steps\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training an Agent with Stable-Baselines3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stable-Baselines3 collects Reinforcement Learning algorithms implemented in Pytorch. \n",
    "\n",
    "Stable-Baselines3 is still a very new library with its current release being 0.9. That is why its collection of algorithms is not very large yet and most algorithms lack more advanced variants. However, its authors planned to broaden the available algorithms in the future. For currently available algorithms see their [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/rl.html).\n",
    "\n",
    "To use an agent provided by Stable-Baselines3 your environment has to have a [gym interface](https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 Imports"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The environment in this control problem poses a discrete action space. Therefore, the [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602) is a suitable agent.\n",
    "For the specific implementation of the DQN you can refer to [Stable-Baslines3's docs](https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html).\n",
    "\n",
    "In this tutorial a multi-layer perceptron (MLP) is used. For this you have to import the DQN and the MlpPolicy. You can also see in the docs which gym spaces for the observation and the actions are supported. You might have to take that into account for your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from stable_baselines3 import DQN\n",
    "from stable_baselines3.dqn import MlpPolicy\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Parameterization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the DQN algorithm you have to define a set of parameters. The policy_kwargs dictionary is a parameter which is directly given to the MlpPolicy. The net_arch key defines the network architecture of the MLP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "buffer_size = 200000 #number of old obsersation steps saved\n",
    "learning_starts = 10000 # memory warmup\n",
    "train_freq = 1 # prediction network gets an update each train_freq's step\n",
    "batch_size = 25 # mini batch size drawn at each update step\n",
    "policy_kwargs = {\n",
    "        'net_arch': [64,64] # hidden layer size of MLP\n",
    "        }\n",
    "exploration_fraction = 0.1 # Fraction of training steps the epsilon decays \n",
    "target_update_interval = 1000 # Target network gets updated each target_update_interval's step\n",
    "gamma = 0.99\n",
    "verbose = 1 # verbosity of stable-basline's prints"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Additionally, you have to define how long your agent shall train. You can just set a concrete number of steps or use knowledge of the environment's temporal resolution to define an in-simulation training time. In this example, the agent is trained for five seconds which translates in this environment's case to 500000 steps."
   ]
  },
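  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The step count quoted above follows from dividing the simulated time by the sampling interval. As a worked sketch using this notebook's values ($T$ for the simulated training time, $N$ for the number of training steps and $N_{max}$ for the episode step limit are symbols introduced here for illustration):\n",
    "\n",
    "$$N = \\frac{T}{\\tau} = \\frac{5\\,\\mathrm{s}}{10^{-5}\\,\\mathrm{s}} = 500000, \\qquad \\frac{N}{N_{max}} = \\frac{500000}{10000} = 50$$\n",
    "\n",
    "Because the ```TimeLimit``` wrapper ends an episode after at most 10000 steps, training spans at least 50 episodes. Episodes that terminate early, for example after a limit violation, increase this number."
   ]
  },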
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tau = 1e-5\n",
    "simulation_time = 5 # seconds\n",
    "nb_steps = int(simulation_time // tau)"
   ]
  },
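  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the total number of steps fixed, the length of the exploration phase follows from ```exploration_fraction```. The sketch below assumes the Stable-Baselines3 defaults for the parameters that this notebook does not set, namely a linear decay from ```exploration_initial_eps = 1.0``` to ```exploration_final_eps = 0.05```. The symbols $f$ (exploration fraction), $N$ (training steps), $\\varepsilon_0$ and $\\varepsilon_f$ are shorthand introduced here, and $\\varepsilon$ denotes the exploration rate, not the rotor angle:\n",
    "\n",
    "$$\\varepsilon(t) = \\max\\left(\\varepsilon_f,\\; \\varepsilon_0 - (\\varepsilon_0 - \\varepsilon_f)\\,\\frac{t}{f N}\\right), \\qquad f N = 0.1 \\cdot 500000 = 50000$$\n",
    "\n",
    "Under these assumptions, the agent acts almost entirely at random at first and reaches its final exploration rate after 50000 steps, i.e. after the first 10 % of training."
   ]
  },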
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Training of the Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you've setup the environment and defined your parameters starting the training is nothing more than a one-liner. For each algorithm all you have to do is call its ```learn()``` function. However, you should note that the execution of the training can take a long time. Currently, Stable-Baselines3 does not provide any means of saving the training reward for later visualization. Therefore, a ```RewardLogger``` callback is used for this environment (see code a few cells above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = DQN(MlpPolicy, env, buffer_size=buffer_size, learning_starts=learning_starts ,train_freq=train_freq, \n",
    "            batch_size=batch_size, gamma=gamma, policy_kwargs=policy_kwargs, \n",
    "            exploration_fraction=exploration_fraction, target_update_interval=target_update_interval,\n",
    "            verbose=verbose)\n",
    "model.learn(total_timesteps=nb_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4 Saving the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When the training has finished you can save the model your DQN has learned to reuse it later, e.g. for evaluation or if you want to continue your training. For this, each Stable-Baselines3 algorithm has a ```.save()``` function where you only have to specify your path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_path = Path.cwd() / \"saved_agents\" \n",
    "log_path.mkdir(parents=True, exist_ok=True)\n",
    "model.save(str(log_path / \"TutorialAgent\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Evaluating an Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After you have trained your agent you would like to see how well it does on your control problem. For this you can look at a visual representation of your currents in a test trajectory or see how well your agent does in a test scenario."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 Loading a Model (Optional)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, before you start your evaluation you have to load a trained agent or take the trained agent from above which is still saved in the variable ```model```. To load a trained agent you simply have to call the ```load()``` function of your algorithm with the respective path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DQN.load(log_path / 'TutorialAgent')  #your agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Taking a Look at the Mean Reward per Episode During Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ```RewardLogger``` callback saved the mean reward per episode during training. There you can observe how the training reward did grow over time and if any problems occured during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards = reward_logger.mean_episode_rewards # training rewards\n",
    "plt.figure()\n",
    "plt.grid(True)\n",
    "plt.xlim(0,len(rewards))\n",
    "plt.ylim(min(rewards), 0)\n",
    "plt.yticks(np.arange(min(rewards), 1, 1.0))\n",
    "plt.tick_params(axis = 'y', left = False, labelleft = False)\n",
    "plt.xticks(np.arange(0, len(rewards), 10))\n",
    "plt.xlabel('#Episode')\n",
    "plt.ylabel('Mean Reward per Episode (Qualitative)')\n",
    "plt.plot(rewards)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 Test Trajectory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can take a look at a test trajectory to see how well your trained agent is able to control the currents to follow the test trajectory. For the agent to decide for an action given an observation you can just call its ```predict()``` function. The key deterministic is important so that the agent is not using a stochastic policy like epsilon greedy but is instead chosing an action greedily. The ```env.render()``` will then visualize the agent's and reference generator's trajectories as well as the reward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "motor_dashboard.initialize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualization_steps = int(9e4)\n",
    "obs, _ = env.reset()\n",
    "for i in range(visualization_steps):\n",
    "    action, _states = model.predict(obs, deterministic=True)\n",
    "    obs, reward, done, trun, _ = env.step(action)\n",
    "    env.render()\n",
    "    if done or trun:\n",
    "        obs, _ = env.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.4 Calculating further evaluation parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the knowledge you acquired in the previous sections you are now able to train and evaluate any in Stable-Baselines3 available reinforcement learning algorithm. The code below should give you an example how to use the trained agent to calculate a mean reward and mean episode length over a specific amount of steps. For further questions you can always have a look at the documentation of gym-electric-motor and Stable-Baselines3 or raise an issue in their respective GitHub repositories."
   ]
  },
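  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two quantities computed by the following cell can be stated compactly. In this sketch, $N$ is the number of test steps, $r_k$ the reward received at step $k$, $M$ the number of episodes completed within those steps and $L_j$ the length of episode $j$ (all symbols are introduced here for illustration):\n",
    "\n",
    "$$\\bar{r} = \\frac{1}{N}\\sum_{k=1}^{N} r_k, \\qquad \\bar{L} = \\frac{1}{M}\\sum_{j=1}^{M} L_j$$\n",
    "\n",
    "An episode that is still running when the $N$ steps are used up does not contribute to $\\bar{L}$."
   ]
  },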
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_steps = int(1e5)\n",
    "cum_rew = 0\n",
    "episode_step = 0\n",
    "episode_lengths = []\n",
    "obs, _ = env.reset()\n",
    "for i in range(test_steps):\n",
    "    print(f\"{i+1} / {test_steps}\", end = '\\r')\n",
    "    episode_step += 1\n",
    "    action, _states = model.predict(obs, deterministic=True)\n",
    "    obs, reward, done, trun, _ = env.step(action)\n",
    "    cum_rew += reward\n",
    "    if done or trun:\n",
    "        episode_lengths.append(episode_step)\n",
    "        episode_step = 0\n",
    "        obs, _ = env.reset()\n",
    "print(f\"The reward per step with {test_steps} steps was: {cum_rew/test_steps:.4f} \")\n",
    "print(f\"The average Episode length was: {round(np.mean(episode_lengths))} \")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "GEM2v0",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
